/*
 *
 * Copyright 2022 tars authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
package main

import (
	"fmt"
	"strings"

	"google.golang.org/protobuf/compiler/protogen"
)

// Plugin version and Go import paths of the packages referenced by the
// generated code. The version is printed in the header written by
// GenerateFile; each import path is registered with the generated file in
// GenerateImports.
const (
	version         = "1.0.0"
	protoPackage    = protogen.GoImportPath("google.golang.org/protobuf/proto")
	modelPackage    = protogen.GoImportPath("github.com/TarsCloud/TarsGo/tars/model")
	requestfPackage = protogen.GoImportPath("github.com/TarsCloud/TarsGo/tars/protocol/res/requestf")
	tarsPackage     = protogen.GoImportPath("github.com/TarsCloud/TarsGo/tars")
	toolsPackage    = protogen.GoImportPath("github.com/TarsCloud/TarsGo/tars/util/tools")
	currentPackage  = protogen.GoImportPath("github.com/TarsCloud/TarsGo/tars/util/current")
	contextPackage  = protogen.GoImportPath("context")
	fmtPackage      = protogen.GoImportPath("fmt")
)

// GenerateFile writes the <prefix>_tars.pb.go companion of file, containing
// the tarsrpc bindings for every service the file declares. It returns nil,
// without creating an output file, when the file declares no services.
func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
	if len(file.Services) == 0 {
		return nil
	}

	out := gen.NewGeneratedFile(file.GeneratedFilenamePrefix+"_tars.pb.go", file.GoImportPath)

	// Header: generator banner, tool versions and the originating .proto.
	out.P("// Code generated by protoc-gen-go-tarsrpc. DO NOT EDIT.")
	out.P("// versions:")
	out.P("// - protoc-gen-go-tarsrpc v", version)
	out.P("// - protoc             ", protocVersion(gen))
	switch {
	case file.Proto.GetOptions().GetDeprecated():
		out.P("// ", file.Desc.Path(), " is a deprecated file.")
	default:
		out.P("// source: ", file.Desc.Path())
	}
	out.P()
	out.P("package ", file.GoPackageName)
	out.P()

	// Body: hand the output file to the tarsrpc generator.
	plugin := tarsrpc{}
	plugin.Init(out)
	plugin.Generate(file)
	return out
}

// protocVersion renders the protoc version carried by the plugin request as
// "vMAJOR.MINOR.PATCH", followed by "-SUFFIX" when a suffix is set. It returns
// "(unknown)" when the request carries no compiler version.
func protocVersion(gen *protogen.Plugin) string {
	ver := gen.Request.GetCompilerVersion()
	if ver == nil {
		return "(unknown)"
	}

	out := fmt.Sprintf("v%d.%d.%d", ver.GetMajor(), ver.GetMinor(), ver.GetPatch())
	if suffix := ver.GetSuffix(); suffix != "" {
		out += "-" + suffix
	}
	return out
}

// tarsrpc is an implementation of the Go protocol buffer compiler's
// plugin architecture. It generates bindings for tars rpc support.
type tarsrpc struct {
	// gen is the output file that all generated code is written to; it is
	// set by Init.
	gen *protogen.GeneratedFile
}

// Name returns the name of this plugin, "tarsrpc".
func (t *tarsrpc) Name() string {
	return "tarsrpc"
}

// Init initializes the plugin by storing gen as the output file that every
// subsequent P call writes to.
func (t *tarsrpc) Init(gen *protogen.GeneratedFile) {
	t.gen = gen
}

// upperFirstLatter returns s with its first character converted to upper case
// and the remainder left untouched. An empty string is returned unchanged.
//
// The first character is taken as a whole UTF-8 encoded rune rather than a
// single byte, so a multi-byte leading character such as "é" is upper-cased
// instead of being split and corrupted. s is expected to be valid UTF-8.
func upperFirstLatter(s string) string {
	// Ranging over a string yields the byte offset at which each rune starts,
	// so the offset of the second rune is the encoded width of the first.
	// A string holding at most one rune never reaches i > 0 and keeps len(s).
	width := len(s)
	for i := range s {
		if i > 0 {
			width = i
			break
		}
	}
	return strings.ToUpper(s[:width]) + s[width:]
}

// GenerateImports registers every package the generated code refers to with
// the output file by resolving one identifier from each of them. The returned
// qualified names are discarded; the file argument is not used.
func (t *tarsrpc) GenerateImports(file *protogen.File) {
	// One representative identifier per package, resolved in this order.
	idents := []protogen.GoIdent{
		protoPackage.Ident("Marshal"),
		modelPackage.Ident("Servant"),
		requestfPackage.Ident("ResponsePacket"),
		tarsPackage.Ident("AddServant"),
		toolsPackage.Ident("Int8ToByte"),
		contextPackage.Ident("Context"),
		currentPackage.Ident("GetResponseContext"),
		fmtPackage.Ident("Errorf"),
	}
	for _, ident := range idents {
		t.gen.QualifiedGoIdent(ident)
	}
}

// P forwards its arguments to t.gen.P, writing one line of generated output.
func (t *tarsrpc) P(args ...interface{}) { t.gen.P(args...) }

// Generate emits the tarsrpc bindings for each service declared in file.
// A file without services produces no output and registers no imports.
func (t *tarsrpc) Generate(file *protogen.File) {
	services := file.Services
	if len(services) == 0 {
		return
	}

	t.GenerateImports(file)
	for idx := range services {
		t.generateService(file, services[idx], idx)
	}
}

// generateService emits every declaration belonging to one proto service, in
// this order: the servant-holding struct, SetServant, AddServant,
// AddServantWithContext, TarsSetTimeout, TarsSetProtocol, the
// PB<Service>Servant and PB<Service>ServantWithContext interfaces, the
// Dispatch method, and finally the client methods of each RPC.
// The index argument is not used.
func (t *tarsrpc) generateService(file *protogen.File, service *protogen.Service, index int) {
	svc := upperFirstLatter(service.GoName)

	// emit expands one template and writes it with a single P call.
	emit := func(format string, args ...interface{}) {
		t.P(fmt.Sprintf(format, args...))
	}

	t.P("// This following code was generated by tarsrpc")
	t.P("// Gernerated from ", file.GeneratedFilenamePrefix)
	emit(`type  %s struct {
		servant model.Servant
	}
	`, svc)
	t.P()

	// SetServant
	emit(`// SetServant is required by the servant interface.
	func (obj *%s) SetServant(servant model.Servant){
		obj.servant = servant
	}
	`, svc)
	t.P()

	// AddServant
	emit(`// AddServant is required by the servant interface
	func (obj *%s) AddServant(imp PB%sServant, objStr string){
		tars.AddServant(obj, imp, objStr)
	}`, svc, svc)

	// AddServantWithContext
	emit(`// AddServantWithContext adds servant  for the service with context
	func (obj *%s) AddServantWithContext(imp PB%sServantWithContext, objStr string) {
		tars.AddServantWithContext(obj, imp, objStr)
	}`, svc, svc)
	t.P()

	// TarsSetTimeout
	emit(`// TarsSetTimeout is required by the servant interface. t is the timeout in ms. 
	func (obj *%s) TarsSetTimeout(t int){
		obj.servant.TarsSetTimeout(t)
	}
	`, svc)
	t.P()

	// TarsSetProtocol
	emit(`// TarsSetProtocol is required by the servant interface. t is the protocol. 
	func (obj *%s) TarsSetProtocol(p model.Protocol){
		obj.servant.TarsSetProtocol(p)
	}
	`, svc)
	t.P()

	// Server-side interface without a context argument.
	t.P("type PB", svc, "Servant interface{")
	for _, m := range service.Methods {
		in := t.gen.QualifiedGoIdent(m.Input.GoIdent)
		out := t.gen.QualifiedGoIdent(m.Output.GoIdent)
		t.P(upperFirstLatter(m.GoName), " (input ", in, ") (output ", out, ", err error)")
	}
	t.P("}")
	t.P()

	// Server-side interface whose methods also receive a context.
	t.P("type PB", svc, "ServantWithContext interface{")
	for _, m := range service.Methods {
		in := t.gen.QualifiedGoIdent(m.Input.GoIdent)
		out := t.gen.QualifiedGoIdent(m.Output.GoIdent)
		t.P(upperFirstLatter(m.GoName), " (ctx context.Context, input ", in, ") (output ", out, ", err error)")
	}
	t.P("}")
	t.P()

	// Dispatcher first, then one pair of client methods per RPC.
	t.generateDispatch(service)
	for _, m := range service.Methods {
		t.generateClientCode(service, m)
	}
}
// generateClientCode emits the two client-side methods for one RPC on the
// service struct:
//
//   - <Method>, which calls <Method>WithContext with context.Background();
//   - <Method>WithContext, which marshals the input with proto.Marshal, calls
//     obj.servant.TarsInvoke using the method name from the proto descriptor,
//     and unmarshals resp.SBuffer into the output.
//
// In the generated code the optional opts maps are read as the request
// context (opts[0]) and status (opts[1]); after a successful call they are
// cleared and refilled from resp.Context and resp.Status respectively.
func (t *tarsrpc) generateClientCode(service *protogen.Service, method *protogen.Method) {
	methodName := upperFirstLatter(method.GoName)
	serviceName := upperFirstLatter(service.GoName)
	inType := t.gen.QualifiedGoIdent(method.Input.GoIdent)
	outType := t.gen.QualifiedGoIdent(method.Output.GoIdent)
	// Context-free convenience wrapper.
	t.P(fmt.Sprintf(`// %s is client rpc method as defined
		func (obj *%s) %s(input %s, opts ...map[string]string)(output %s, err error){
			ctx := context.Background()
			return obj.%sWithContext(ctx, input, opts...)
		}
	`, methodName, serviceName, methodName, inType, outType, methodName))

	// Full implementation. The last template argument is the RPC name passed
	// to TarsInvoke: the descriptor name, not the Go method name.
	t.P(fmt.Sprintf(`// %sWithContext is client rpc method as defined
		func (obj *%s) %sWithContext(ctx context.Context, input %s, opts ...map[string]string)(output %s, err error){
			var inputMarshal []byte
			inputMarshal, err = proto.Marshal(&input)
			if err != nil {
				return output, err
			}

			var statusMap map[string]string
			var contextMap map[string]string
			if len(opts) == 1 {
				contextMap = opts[0]
			} else if len(opts) == 2 {
				contextMap = opts[0]
				statusMap = opts[1]
			}

			resp := new(requestf.ResponsePacket)
			err = obj.servant.TarsInvoke(ctx, 0, "%s", inputMarshal, statusMap, contextMap, resp)
			if err != nil {
				return output, err
			}
			if err = proto.Unmarshal(tools.Int8ToByte(resp.SBuffer), &output); err != nil{
				return output, err
			}

			if len(opts) == 1 {
				for k := range contextMap {
					delete(contextMap, k)
				}
				for k, v := range resp.Context {
					contextMap[k] = v
				}
			} else if len(opts) == 2 {
				for k := range contextMap {
					delete(contextMap, k)
				}
				for k, v := range resp.Context {
					contextMap[k] = v
				}
				for k := range statusMap {
					delete(statusMap, k)
				}
				for k, v := range resp.Status {
					statusMap[k] = v
				}
			}
			return output, nil
		}
	`, methodName, serviceName, methodName, inType, outType, method.Desc.Name()))
}
// generateDispatch emits the Dispatch method of the service struct. The
// output is assembled from three pieces that together form one function:
//
//  1. the method header and the opening of a switch on req.SFuncName;
//  2. one case per RPC, keyed by the method name from the proto descriptor,
//     which unmarshals the request, calls the implementation through
//     PB<Service>Servant or PB<Service>ServantWithContext depending on
//     withContext, and marshals the result into output;
//  3. the default case returning a "func mismatch" error, followed by the
//     code that fills *resp with the output buffer and the status and
//     context maps obtained from the current package.
func (t *tarsrpc) generateDispatch(service *protogen.Service) {
	serviceName := upperFirstLatter(service.GoName)
	// Piece 1: function header and switch opening.
	t.P(fmt.Sprintf(`// Dispatch is used to call the user implement of the defined method.
	func (obj *%s) Dispatch(ctx context.Context, val interface{}, req *requestf.RequestPacket, resp *requestf.ResponsePacket, withContext bool)(err error){
		input := tools.Int8ToByte(req.SBuffer)
		var output []byte
		funcName := req.SFuncName
		switch funcName {
	`, serviceName))
	// Piece 2: one switch case per RPC method.
	for _, method := range service.Methods {
		t.P(fmt.Sprintf(`case "%s":
			inputDefine := %s{}
			if err = proto.Unmarshal(input,&inputDefine); err != nil{
				return err
			}
			var res %s
            if !withContext {
				imp := val.(PB%sServant)
				res, err = imp.%s(inputDefine)
				if err != nil {
					return err
				}
			}else {
				imp := val.(PB%sServantWithContext)
				res, err = imp.%s(ctx, inputDefine)
				if err != nil {
					return err
				}
			}
			output , err = proto.Marshal(&res)
			if err != nil {
				return err
			}
		`, method.Desc.Name(), t.gen.QualifiedGoIdent(method.Input.GoIdent), t.gen.QualifiedGoIdent(method.Output.GoIdent), serviceName, upperFirstLatter(method.GoName), serviceName, upperFirstLatter(method.GoName)))
	}
	// Piece 3: default case, response packet assembly and function close.
	t.P(`default:
			return fmt.Errorf("func mismatch")
	}
	var statusMap map[string]string
	if status, ok := current.GetResponseStatus(ctx); ok && status != nil {
		statusMap = status
	}
	var contextMap map[string]string
	if ctx, ok := current.GetResponseContext(ctx); ok && ctx != nil {
		contextMap = ctx
	}
	*resp = requestf.ResponsePacket{
		IVersion:     1,
		CPacketType:  0,
		IRequestId:   req.IRequestId,
		IMessageType: 0,
		IRet:         0,
		SBuffer:      tools.ByteToInt8(output),
		Status:       statusMap,
		SResultDesc:  "",
		Context:      contextMap,
	}
	return nil
}
	`)
	t.P()
}
